


def get_parameter_number(net, *, trainable_only=False):
    """Return the total number of scalar elements in ``net``'s parameters.

    Args:
        net: Object exposing a ``parameters()`` method that yields
            tensor-like objects with a ``numel()`` method (presumably a
            ``torch.nn.Module`` -- confirm against callers).
        trainable_only: Keyword-only. If True, count only parameters whose
            ``requires_grad`` attribute is truthy. Defaults to False, which
            counts every parameter yielded by ``net.parameters()``.

    Returns:
        The sum of ``p.numel()`` over the selected parameters; ``0`` when
        there are none.
    """
    return sum(
        p.numel()
        for p in net.parameters()
        # Short-circuit keeps the default path from touching requires_grad.
        if not trainable_only or p.requires_grad
    )

